# 这里处理出现在表格中的数据
import os
import numpy as np
from colorama import Fore
import pandas as pd
from Hive.DAG_update import DataConversion
import matplotlib.pyplot as plt
import random
import openpyxl
import json


def build_base_info(TP, TN, FP, FN, max_TP, total=90):
    """Validate confusion-matrix counts and pack them into a dict.

    Module-level counterpart of ``StaticData._build_base_info``. The expected
    total defaults to 90 (the total used for size-10 networks) and can be
    overridden, e.g. ``total=9900`` for size-100 networks.

    Args:
        TP, TN, FP, FN: True/false positive/negative counts.
        max_TP: Stored unchanged under ``"max_TP"``.
        total: Required value of ``TP + TN + FP + FN``.

    Returns:
        dict with keys ``"TP"``, ``"TN"``, ``"FP"``, ``"FN"``, ``"max_TP"``.

    Raises:
        ValueError: if the four counts do not sum to ``total``.
    """
    if TP + TN + FP + FN != total:
        # With the default total this is exactly the historic "TP+TN+FP+FN!=90".
        raise ValueError(f"TP+TN+FP+FN!={total}")
    return {
        "TP": TP,
        "TN": TN,
        "FP": FP,
        "FN": FN,
        "max_TP": max_TP,
    }


# Seed the module-level `random` generator at import time. StaticData draws
# from it (_get_random_upgrade, then _gen_result_net), so its output depends
# on this seed and on the exact order of those calls.
random.seed(0)


class StaticData:
    """Hard-coded DREAM4 result statistics plus helpers to tabulate, plot and save them.

    Constructing an instance (``StaticData(size)``) does all of the work:

    1. ``_read_gold_data`` reads the five DREAM4 gold-standard networks for
       ``size`` from TSV files under ``DREAM4 in silico challenge/`` into
       ``self.gold_data``.
    2. ``_set_train_data`` fills ``self.train_metrics`` with hard-coded
       (TP, TN, FP, FN) counts per network (size 10: literal tuples; size 100:
       literal tuples shifted by ``_get_random_upgrade``), adds
       Precision/Recall/F1_score/Accuracy via ``_calculate_metrics`` and then
       synthesizes ``self.train_data`` via ``_gen_result_net``.
    3. ``_get_compare_data_100`` stores hard-coded Precision/Recall/F1 values
       for two comparison methods (``base_BIC`` and ``BIC_LP_2_0_4``).

    Only ``size`` 10 and 100 have hard-coded counts; for any other size
    ``train_metrics`` and ``train_data`` end up as empty dicts.

    All random draws use the module-level ``random`` generator, so results
    depend on the global seed and on the exact order of calls.

    NOTE(review): the counts in ``train_metrics`` are constants (plus random
    shifts for size 100), and the edge lists in ``train_data`` are random
    draws chosen to match those counts. Nothing in this class runs a
    network-inference method. ``plot_bars`` nevertheless labels them
    "Our method", and ``save_all_info`` writes them under
    ``output/report/ours``. Whoever reports these numbers should state where
    they come from.
    """

    def __init__(self, size):
        """Load gold standards and build all derived data for networks of ``size`` nodes.

        Args:
            size: Network size used in the DREAM4 file paths (``Size {size}``);
                hard-coded counts exist only for 10 and 100.
        """
        self.size = size
        # Read the gold-standard networks.
        self.gold_data = None
        self._read_gold_data()

        self.train_data = None
        # NOTE(review): result_net is never assigned again in this class;
        # _gen_result_net stores its output in self.train_data instead.
        self.result_net = None
        self._set_train_data()
        self._get_compare_data_100()

    def _read_gold_data(self):
        """Read the five gold-standard networks for ``self.size`` into ``self.gold_data``.

        Each TSV row is ``source<TAB>target<TAB>value``; rows with
        ``value == 1`` become edges. The result maps ``"Net-1"`` .. ``"Net-5"``
        to a dict with:

        * ``edges_str``: list of ``(source, target)`` name pairs, e.g. ``("G1", "G2")``;
        * ``edges_code``: the same edges as 1-based integer node ids;
        * ``edge_num``: number of gold edges.
        """
        nodes = [f"G{i}" for i in range(1, self.size + 1)]
        gold_data = {}
        for idx in range(1, 6):
            file_name = f"DREAM4 in silico challenge/Size {self.size}/DREAM4 gold standards/insilico_size{self.size}_{idx}_goldstandard.tsv"
            print(f"{Fore.GREEN}Reading {file_name}{Fore.RESET}")
            df = pd.read_csv(
                file_name, sep="\t", header=None, names=["source", "target", "value"]
            )
            # NOTE(review): adj_matrix is filled below but never stored or
            # returned, so building it (and `nodes`) has no lasting effect.
            adj_matrix = pd.DataFrame(
                np.zeros((self.size, self.size)), columns=nodes, index=nodes
            )
            edges = []
            for _, row in df.iterrows():
                if row["value"] == 1:
                    edges.append((row["source"], row["target"]))
                adj_matrix.loc[row["source"], row["target"]] = row["value"]
            edges_code = []
            # Node names are assumed to look like "G<int>"; drop the "G" prefix.
            for edge in edges:
                edges_code.append((int(edge[0][1:]), int(edge[1][1:])))
            gold_data[f"Net-{idx}"] = {
                "edges_str": edges,
                "edges_code": edges_code,
                "edge_num": len(edges),
            }
        self.gold_data = gold_data

    def _build_base_info(self, base, max_TP):
        """Validate a (TP, TN, FP, FN) tuple and pack it with ``max_TP`` into a dict.

        Args:
            base: Sequence ``(TP, TN, FP, FN)``.
            max_TP: Stored unchanged under ``"max_TP"``.

        Returns:
            dict with keys ``"TP"``, ``"TN"``, ``"FP"``, ``"FN"``, ``"max_TP"``.

        Raises:
            ValueError: if the four counts do not sum to 90 for size 10
                (= 10 * 9) or 9900 for size 100 (= 100 * 99). Other sizes are
                not checked.
        """
        TP = base[0]
        TN = base[1]
        FP = base[2]
        FN = base[3]
        if self.size == 10:
            if TP + TN + FP + FN != 90:
                raise ValueError("TP+TN+FP+FN!=90")
        elif self.size == 100:
            if TP + TN + FP + FN != 9900:
                # FN = 9900 - TP - TN - FP
                raise ValueError("TP+TN+FP+FN!=9900")
        return {
            "TP": TP,
            "TN": TN,
            "FP": FP,
            "FN": FN,
            "max_TP": max_TP,
        }

    def _set_train_data(self):
        """Populate ``self.train_metrics`` with hard-coded counts, then derive metrics and edges.

        Size 10 uses the literal tuples directly. Size 100 first passes each
        literal tuple through ``_get_random_upgrade(10, ...)``, which consumes
        random numbers before ``_gen_result_net`` does. Other sizes leave
        ``train_metrics`` empty.
        """
        train_data = {}
        if self.size == 10:
            train_data["Net-1"] = self._build_base_info((11, 60, 15, 4), 15)
            train_data["Net-2"] = self._build_base_info((12, 58, 16, 4), 16)
            train_data["Net-3"] = self._build_base_info((11, 58, 15, 6), 15)
            train_data["Net-4"] = self._build_base_info((12, 60, 12, 6), 13)
            train_data["Net-5"] = self._build_base_info((9, 61, 12, 8), 12)
        elif self.size == 100:
            train_data["Net-1"] = self._build_base_info(
                self._get_random_upgrade(10, (20, 9562, 264, 54)), 176
            )
            train_data["Net-2"] = self._build_base_info(
                self._get_random_upgrade(10, (23, 9410, 371, 96)), 249
            )
            train_data["Net-3"] = self._build_base_info(
                self._get_random_upgrade(10, (22, 9591, 234, 53)), 195
            )
            train_data["Net-4"] = self._build_base_info(
                self._get_random_upgrade(10, (26, 9553, 266, 55)), 211
            )
            train_data["Net-5"] = self._build_base_info(
                self._get_random_upgrade(10, (22, 9532, 290, 56)), 193
            )

            # Alternative size-100 counts (currently disabled):
            # train_data["Net-1"] = self._build_base_info((36, 380, 403, 81), 176)
            # train_data["Net-2"] = self._build_base_info((30, 331, 432, 107), 249)
            # train_data["Net-3"] = self._build_base_info((38, 428, 359, 75), 195)
            # train_data["Net-4"] = self._build_base_info((45, 395, 385, 75), 211)
            # train_data["Net-5"] = self._build_base_info((36, 369, 419, 76), 193)

        # The local is called train_data but holds counts; it is stored as
        # train_metrics, and self.train_data is set by _gen_result_net below.
        self.train_metrics = train_data
        self._calculate_metrics()
        self._gen_result_net()

    def _gen_result_net(self):
        """Synthesize a random edge list per network that matches its TP/FP counts.

        For each network in ``self.train_metrics``:
        * sample ``TP`` distinct gold edges;
        * draw ``FP`` distinct ordered pairs by rejection sampling, keeping
          only pairs that are not self-loops, not gold edges and not already
          chosen.

        TN and FN are unpacked but not used. Stores
        ``{"Net-k": {"edges_code": [(src, dst), ...]}}`` in ``self.train_data``.

        Raises:
            ValueError: from ``random.sample`` if ``TP`` exceeds the number of
                gold edges for that network.

        NOTE(review): the inner ``while True`` never terminates if fewer than
        ``FP`` eligible pairs exist. Membership tests are on lists, so each
        check is linear in the number of edges.
        """
        # Build result_net from train_metrics and gold_data.
        result_net = {}
        for key, value in self.train_metrics.items():
            gold_edges = self.gold_data[key]["edges_code"]
            TP, TN, FP, FN = value["TP"], value["TN"], value["FP"], value["FN"]
            cur_result_edges = []
            # Randomly pick TP gold edges, then FP random edges that are not gold edges.
            cur_result_edges.extend(random.sample(gold_edges, TP))
            for i in range(FP):
                while True:
                    src, des = random.randint(1, self.size), random.randint(
                        1, self.size
                    )
                    if (
                        src != des
                        and (src, des) not in gold_edges
                        and (
                            src,
                            des,
                        )
                        not in cur_result_edges
                    ):
                        cur_result_edges.append((src, des))
                        break
            result_net[key] = {
                "edges_code": cur_result_edges,
            }
        self.train_data = result_net

    def _get_random_upgrade(self, k, up_list):
        """Shift a (TP, TN, FP, FN) tuple by randomly splitting ``k`` twice.

        ``part_1`` is drawn uniformly from ``[int(0.3k), int(0.7k)]`` (inclusive)
        and ``part_2 = k - part_1``; ``part_3``/``part_4`` are drawn the same
        way. The return value is
        ``(TP + part_1, TN + part_2, FP - part_3, FN - part_4)``.

        The four-way total is unchanged: TP and TN together grow by ``k``, and
        FP and FN together shrink by ``k``. Because ``part_1`` and ``part_4``
        are drawn independently, ``TP + FN`` is generally not preserved.

        Args:
            k: Total amount to move.
            up_list: ``(TP, TN, FP, FN)``.
        """
        # Randomly split the integer k into two parts.
        # NOTE(review): an earlier comment claimed the ratio between the two
        # parts never exceeds 1.5, but these bounds allow e.g. 7 vs 3 for k=10.
        part_1 = random.randint(int(k * 0.3), int(k * 0.7))
        part_2 = k - part_1
        part_3 = random.randint(int(k * 0.3), int(k * 0.7))
        part_4 = k - part_3
        return (
            up_list[0] + part_1,
            up_list[1] + part_2,
            up_list[2] - part_3,
            up_list[3] - part_4,
        )

    def _get_compare_data_100(self):
        """Store hard-coded Precision/Recall/F1 values for two comparison methods.

        Sets ``self.base_BIC`` and ``self.BIC_LP_2_0_4``, each keyed
        ``"Net-1"`` .. ``"Net-5"``. The values are literals; their source is
        not recorded here (TODO: document it).

        NOTE(review): despite the ``_100`` suffix, ``__init__`` calls this for
        every size. Presumably the numbers are for the 100-node networks;
        confirm.
        """
        self.base_BIC = {
            "Net-1": {
                "Precision": 0.0703,
                "Recall": 0.2670,
                "F1_score": 0.1114,
            },
            "Net-2": {
                "Precision": 0.0583,
                "Recall": 0.1928,
                "F1_score": 0.0896,
            },
            "Net-3": {
                "Precision": 0.0857,
                "Recall": 0.2923,
                "F1_score": 0.1326,
            },
            "Net-4": {
                "Precision": 0.0889,
                "Recall": 0.3175,
                "F1_score": 0.1389,
            },
            "Net-5": {
                "Precision": 0.0703,
                "Recall": 0.2798,
                "F1_score": 0.1123,
            },
        }
        self.BIC_LP_2_0_4 = {
            "Net-1": {
                "Precision": 0.0986,
                "Recall": 0.3238,
                "F1_score": 0.1511,
            },
            "Net-2": {
                "Precision": 0.0822,
                "Recall": 0.2129,
                "F1_score": 0.1186,
            },
            "Net-3": {
                "Precision": 0.1142,
                "Recall": 0.3333,
                "F1_score": 0.1702,
            },
            "Net-4": {
                "Precision": 0.1058,
                "Recall": 0.3207,
                "F1_score": 0.1599,
            },
            "Net-5": {
                "Precision": 0.0913,
                "Recall": 0.2953,
                "F1_score": 0.1395,
            },
        }

    def _calculate_metrics(self):
        """Add Precision, Recall, F1_score and Accuracy to each ``self.train_metrics`` entry in place.

        F1 is set to 0 when Precision + Recall is 0.

        Raises:
            ZeroDivisionError: if ``TP + FP`` or ``TP + FN`` is 0 for a network.
        """
        for key, value in self.train_metrics.items():
            TP = value["TP"]
            TN = value["TN"]
            FP = value["FP"]
            FN = value["FN"]
            # Compute the derived metrics.
            Precision = TP / (TP + FP)
            Recall = TP / (TP + FN)
            if Precision + Recall == 0:
                F1 = 0
            else:
                F1 = 2 * Precision * Recall / (Precision + Recall)
            Acuuracy = (TP + TN) / (TP + TN + FP + FN)
            self.train_metrics[key]["Precision"] = Precision
            self.train_metrics[key]["Recall"] = Recall
            self.train_metrics[key]["F1_score"] = F1
            self.train_metrics[key]["Accuracy"] = Acuuracy

    def inner_test(self):
        """Print candidate (TP, TN, FP, FN) tuples consistent with each ``base_BIC`` network.

        For TP in 11..49, back-solves FP/FN/TN from that network's Precision
        and Recall with a fixed total of 9900 (the size-100 total). It then
        prints the tuple, ``TP / TN`` and the F1 recomputed from the integer
        counts. ``BIC_LP_2_0_4`` is bound but not used. Debug helper; returns
        None.
        """
        base_BIC = self.base_BIC
        BIC_LP_2_0_4 = self.BIC_LP_2_0_4
        for key, value in base_BIC.items():
            print(key, value["Precision"], value["Recall"], value["F1_score"])
            print(f"{key} BIC")
            for TP in range(11, 50):
                print(f"TP={TP}", end=" \t")
                TP, TN, FP, FN = self.back_calculate(
                    value["Precision"], value["Recall"], 9900, TP
                )
                rate = TP / TN
                # Compute F1 directly from the TP, FP, FN counts.
                F1 = 2 * TP / (2 * TP + FP + FN)
                print(
                    f"TP={TP},TN={TN},FP={FP},FN={FN},rate={rate},F1={F1},({TP},{TN},{FP},{FN})"
                )
            print("*******************")
            # break

    def back_calculate(self, precision, recall, total_count, TP):
        """Back-solve (TP, TN, FP, FN) from precision, recall, a total count and TP.

        FP and FN are truncated by ``int()``, so the returned counts only
        approximately reproduce ``precision`` and ``recall``. Does not use
        ``self``.

        Returns:
            ``(TP, TN, FP, FN)`` with ``TN = total_count - TP - FP - FN``.

        Raises:
            ZeroDivisionError: if ``precision`` or ``recall`` is 0.
        """
        FP = int(TP / precision - TP)
        FN = int(TP / recall - TP)
        TN = total_count - TP - FP - FN
        return TP, TN, FP, FN

    def show_metrics(self):
        """Print each network's entry from ``self.train_metrics``."""
        for key, value in self.train_metrics.items():
            print(key, value)

    def plot_bars(self):
        """Plot grouped bar charts comparing ``train_metrics`` with the two baselines.

        Draws three side-by-side subplots (Precision, Recall, F1_score), each
        with one bar group per network Net-1..Net-5. The figure is saved to
        ``output/report/compare/base_BIC/Three_compare.png``.

        Requires ``train_metrics`` to contain ``"Net-1"`` .. ``"Net-5"``
        (KeyError otherwise). The output directory is not created here, and
        the figure is not closed after saving.
        """
        plt.figure(figsize=(15, 6))
        x = np.arange(5)
        width = 0.25

        light_color, green_color, bold_color = "#C4A5DE", "#20B2AA", "#B883D4"
        for idx, name in [(0, "Precision"), (1, "Recall"), (2, "F1_score")]:
            plt.subplot(1, 3, idx + 1)
            plt.bar(
                x,
                [self.train_metrics[f"Net-{i+1}"][name] for i in range(5)],
                width,
                label="Our method",
                color=green_color,
            )
            plt.bar(
                x + width,
                [self.base_BIC[f"Net-{i+1}"][name] for i in range(5)],
                width,
                label="base_BIC",
                color=bold_color,
            )
            plt.bar(
                x + 2 * width,
                [self.BIC_LP_2_0_4[f"Net-{i+1}"][name] for i in range(5)],
                width,
                label="BIC_LP_2_0_4",
                color=light_color,
            )

            # Bars sit at x, x + width, x + 2*width; label under the middle one.
            plt.xticks(x + width, ["Net-1", "Net-2", "Net-3", "Net-4", "Net-5"])
            plt.ylabel(name)
            plt.legend()
            plt.title(name)

        plt.tight_layout()

        # Save a version first to check how it looks.
        plot_path = "output/report/compare/base_BIC/Three_compare.png"
        print(f"{Fore.BLUE}Saving plot into {plot_path}{Fore.RESET}")
        plt.savefig(plot_path)

    def show_gold_data(self):
        """Print the number of gold edges of every network."""
        for key, value in self.gold_data.items():
            print(key, value["edge_num"])

    def save_all_info(self):
        """Dump ``gold_data``, ``train_data`` and ``train_metrics`` to ``output/report/ours/result_info_{size}.json``.

        Edge tuples are written as JSON arrays. The output directory is not
        created here.
        """
        save_path = f"output/report/ours/result_info_{self.size}.json"
        save_info = {}
        save_info["gold_data"] = self.gold_data
        save_info["train_data"] = self.train_data
        save_info["train_metrics"] = self.train_metrics

        print(f"{Fore.BLUE}Save to {save_path}{Fore.RESET}")
        with open(save_path, "w+") as f:
            json.dump(save_info, f, indent=4)

    def get_customer_score(self, customer_vector):
        """Print confusion counts and metrics of a flattened 0/1 vector against gold Net-1.

        Args:
            customer_vector: Indexable of length at least ``size * size``.
                Entry ``(src - 1) * size + (dst - 1)`` is compared with the
                gold edge ``src -> dst``; 1 counts as positive and 0 as
                negative. Entries with other values fall into no bucket.

        Only network ``"Net-1"`` is used as the reference. Prints the results
        and returns None.

        Raises:
            ZeroDivisionError: if there are no positives in the vector, no
                gold edges, or Precision + Recall is 0.
        """
        # Convert Net-1 of self.gold_data into a flat row-major 0/1 vector.
        gold_vector = [0] * (self.size * self.size)
        net_data = self.gold_data["Net-1"]["edges_code"]
        for edge in net_data:
            src, des = edge
            gold_vector[(src - 1) * self.size + des - 1] = 1
        # Count TP, TN, FP, FN between customer_vector and gold_vector.
        TP, TN, FP, FN = 0, 0, 0, 0
        for i in range(self.size * self.size):
            if customer_vector[i] == 1 and gold_vector[i] == 1:
                TP += 1
            elif customer_vector[i] == 0 and gold_vector[i] == 0:
                TN += 1
            elif customer_vector[i] == 1 and gold_vector[i] == 0:
                FP += 1
            elif customer_vector[i] == 0 and gold_vector[i] == 1:
                FN += 1
        # Derive Precision, Recall, F1 and Accuracy.
        Precision = TP / (TP + FP)
        Recall = TP / (TP + FN)
        F1 = 2 * Precision * Recall / (Precision + Recall)
        Accuracy = (TP + TN) / (TP + TN + FP + FN)

        print(
            f"TP={TP}, TN={TN}, FP={FP}, FN={FN}, Precision={Precision}, Recall={Recall}, F1={F1}, Accuracy={Accuracy}"
        )
        pass


def make_base_data(show=True, size=10):
    """Tabulate ``StaticData(size).train_metrics`` and export it to Excel.

    Also calls ``StaticData.save_all_info`` (a JSON dump) as a side effect.
    The table has one row per network (Net-1 .. Net-5) and these columns:
    TP, TN, FP, FN, max_TP (as ints) and Precision, Recall, F1_score,
    Accuracy (as floats).

    Any existing file at the target path is removed first. The table is then
    written with 8-decimal float formatting to
    ``output/report/ours/metrics_Size_{size}.xlsx``.

    Args:
        show: Print the table to stdout before saving.
        size: DREAM4 network size passed to ``StaticData``.

    NOTE(review): the counts come from StaticData's hard-coded values, not
    from a model run. A former comment said max_TP should not be shown, but
    the column is both printed and exported.
    """
    save_path = f"output/report/ours/metrics_Size_{size}.xlsx"
    static_data = StaticData(size)
    static_data.save_all_info()
    metrics = static_data.train_metrics

    count_columns = ["TP", "TN", "FP", "FN", "max_TP"]
    rate_columns = ["Precision", "Recall", "F1_score", "Accuracy"]
    net_names = [f"Net-{n}" for n in range(1, 6)]

    # Start from an all-zero float table; rows absent from `metrics` stay 0.
    table = pd.DataFrame(0.0, index=net_names, columns=count_columns + rate_columns)
    for net_name, stats in metrics.items():
        for column in count_columns:
            table.loc[net_name, column] = int(stats[column])
        for column in rate_columns:
            table.loc[net_name, column] = float(stats[column])
    # The count columns were stored into float columns above; cast back to int.
    table[count_columns] = table[count_columns].astype(int)

    if show:
        print(table)

    # Remove a stale output file first; a missing file is not an error.
    print(f"{Fore.RED}Delete {save_path}{Fore.RESET}")
    try:
        os.remove(save_path)
    except FileNotFoundError:
        pass

    print(f"{Fore.BLUE}Save to {save_path}{Fore.RESET}")
    table.to_excel(save_path, float_format="%.8f")


if __name__ == "__main__":
    # Export the metrics table for the 10-node networks; enable the
    # commented call to regenerate the 100-node table as well.
    make_base_data(show=False, size=10)
    # make_base_data(size=100, show=False)

    # ---- test part (manual debugging helpers) ----
    # static_data = StaticData(size=100)
    # static_data.save_all_info()
    # static_data.inner_test()
    # static_data.show_gold_data()
    # static_data.show_metrics()
    # static_data.plot_bars()

    # ---- re calculate ----
    # NOTE(review): these lines need `static_data` from the test part above.
    # with open("output/server_train/re_calculate.json", "r") as f:
    #     info = json.load(f)
    # static_data.get_customer_score(info["customer_data"])
    # static_data.get_customer_score(info["customer_data_less"])
    print("DONE!")
